# -*- coding:utf-8 -*-
"""
Created on 2011-2-6

@author: summit
"""
import vtk
import math
import wx
import platform

def getImageScalarRange(image):
    """
    Return the (min, max) display range for an image, chosen by scalar type.

    - "unsigned char" images always give (0, 255).
    - "unsigned short" images give (0, 4095), or (0, 65535) when the data's
      maximum exceeds 4095.
    - Any other scalar type gives the range actually present in the data,
      as reported by image.GetScalarRange().
    """
    low, high = image.GetScalarRange()
    kind = image.GetScalarTypeAsString()

    if kind == "unsigned char":
        return 0, 255
    elif kind == "unsigned short":
        return (0, 65535) if high > 4095 else (0, 4095)
    return low, high

def get_histogram(image, maxval = 0, minval = 0):
    """
    Compute the histogram of an image, one bin per integer scalar value.

    @param image: VTK image whose producer port feeds a vtkImageAccumulate
    @param maxval: upper bound of the binned range. When 0, the range comes
        from getImageScalarRange(image) and minval is ignored.
    @param minval: lower bound of the binned range (used only if maxval != 0)
    @return: list of bin counts (floats). Entry 0 is the bin for
        ceil(lower bound); the last entry is the bin for floor(upper bound).
    """
    accumulator = vtk.vtkImageAccumulate()
    accumulator.SetInputConnection(image.GetProducerPort())

    if maxval != 0:
        low = int(math.ceil(minval))
        high = int(math.floor(maxval))
    else:
        range_low, range_high = getImageScalarRange(image)
        high = int(math.floor(range_high))
        low = int(math.ceil(range_low))

    # Bins start at `low` and are one scalar unit wide.
    accumulator.SetComponentExtent(0, high - low, 0, 0, 0, 0)
    accumulator.SetComponentOrigin(low, 0, 0)
    accumulator.SetComponentSpacing(1, 0, 0)
    accumulator.Update()
    output = accumulator.GetOutput()

    first, last = output.GetWholeExtent()[:2]
    return [output.GetScalarComponentAsDouble(bin_index, 0, 0, 0)
            for bin_index in range(first, last + 1)]

def histogram(imagedata, colorTransferFunction = None, bg = (200, 200, 200), logarithmic = 1,
              ignore_border = 0, lower = 0, upper = 0, percent_only = 0, maxval = 255, minval = 0):
    """
    Draw a histogram of a volume into a wx bitmap.

    @param imagedata: VTK image; its histogram is computed by get_histogram()
    @param colorTransferFunction: optional object with GetColor(value, rgb).
        When given, a 256-column colour bar sampling minval..maxval is drawn
        below the histogram, with maxval written at its right end.
    @param bg: colour used to clear the bitmap (clearing is skipped on Darwin)
    @param logarithmic: if true, bin counts are log-scaled before drawing
    @param ignore_border: if true, the first five bins are set to the maximum
        of the other bins and the last five to the minimum of the other bins
        before drawing (needs more than five bins)
    @param lower, upper: inclusive bin-index range. The fraction of all counted
        voxels falling inside it is returned as `percent`. Disabled when both are 0.
    @param percent_only: accepted for backward compatibility; not used
    @param maxval, minval: scalar range forwarded to get_histogram (maxval == 0
        means "use the image's own range") and used to sample the colour bar
    @return: (bitmap, percent, raw bin counts, x offset of histogram column 0)
    """
    values = get_histogram(imagedata, maxval, minval)
    retvals = values[:]     # raw counts are returned untouched by the scaling below
    xoffset = 10            # left margin reserved for the scale ticks
    plot_height = 150       # height in pixels of the tallest bar

    # Total voxel count, and the count falling inside [lower, upper].
    total = 0
    sumth = 0
    for i, count in enumerate(values):
        total += count
        if (lower or upper) and lower <= i <= upper:
            sumth += count
    print("lower = %d, upper = %d, total amount of %d values"
          % (lower, upper, len(values)))
    percent = 0
    if sumth:
        percent = float(sumth) / total

    if ignore_border:
        ma = max(values[5:])
        mi = min(values[:-5])
        values[:5] = [ma] * 5
        values[-5:] = [mi] * 5

    # Empty bins become 1 so log() is defined for them (log(1) == 0).
    values = [v if v != 0 else 1 for v in values]
    if logarithmic:
        # A real list (not map()) so it can be traversed more than once.
        values = [math.log(v) for v in values]

    # Scale so the tallest bar is plot_height pixels. In log mode every bin may
    # hold <= 1 voxel, which makes the peak 0; draw nothing instead of dividing by 0.
    peak = max(values)
    scale = float(plot_height) / peak if peak > 0 else 0.0
    values = [v * scale for v in values]
    x1 = plot_height        # y coordinate of the histogram baseline
    w = 256 + xoffset + 5

    diff = 0
    if colorTransferFunction:
        diff = 30           # room for the colour bar
    if percent:
        diff += 20          # room for the percentage text
    print("Creating a %dx%d bitmap for histogram" % (w, x1 + diff))

    bmp = wx.EmptyBitmap(w, x1 + diff)
    dc = wx.MemoryDC()
    dc.SelectObject(bmp)
    dc.BeginDrawing()

    blackpen = wx.Pen((0, 0, 0), 1)
    graypen = wx.Pen((80, 80, 80), 1)
    whitepen = wx.Pen((255, 255, 255), 1)

    if platform.system() != "Darwin":
        dc.SetBackground(wx.Brush(bg))
        dc.Clear()
    dc.SetBrush(wx.Brush(wx.Colour(200, 200, 200)))
    dc.DrawRectangle(0, 0, 256 + xoffset + 1, 151)

    # Tick positions for the vertical scale (integer step so range() works
    # on both Python 2 and 3).
    if logarithmic:
        points = [p + 2 for p in (4, 8, 16, 28, 44, 64, 88, 116, 148)]
    else:
        points = range(1, 150, 150 // 8)

    for i in points:
        y = 151 - i
        dc.SetPen(blackpen)
        dc.DrawLine(0, y, 5, y)
        dc.SetPen(whitepen)
        dc.DrawLine(0, y - 1, 5, y - 1)

    # Each of the 256 pixel columns shows the bin nearest its share of the range.
    d = (len(values) - 1) / 255.0
    dc.SetPen(blackpen)
    dc.DrawLine(xoffset - 1, 0, xoffset - 1, 151)
    dc.SetPen(graypen)
    for i in range(256):
        c = values[int(i * d)]
        if c:
            dc.DrawLine(xoffset + i, x1, xoffset + i, int(x1 - c))

    if colorTransferFunction:
        # Column i of the bar samples the transfer function at
        # minval + i * (maxval - minval) / 255. This lines the bar up with the
        # histogram columns and keeps it inside the bitmap for any range.
        step = (maxval - minval) / 255.0
        for i in range(256):
            rgb = [0, 0, 0]
            colorTransferFunction.GetColor(minval + i * step, rgb)
            r, g, b = [int(channel * 255) for channel in rgb]
            dc.SetPen(wx.Pen(wx.Colour(r, g, b), 1))
            dc.DrawLine(xoffset + i, x1 + 8, xoffset + i, x1 + 30)
        dc.SetPen(whitepen)
        dc.SetFont(wx.Font(8, wx.SWISS, wx.NORMAL, wx.NORMAL))
        # maxval sits at the last bar column (255); place the label just left of it.
        dc.DrawText(str(int(maxval)), xoffset + 255 - 25, x1 + 10)
    else:
        print("Got no ctf for histogram")

    dc.EndDrawing()
    dc.SelectObject(wx.NullBitmap)
    dc = None
    return bmp, percent, retvals, xoffset

if __name__ == "__main__":
    # Demo: draw the histogram of the VTK "headsq/quarter" sample volume and
    # save it as a BMP file.
    from vtk.util.misc import vtkGetDataRoot
    app = wx.PySimpleApp()      # create the wx application before any drawing
    data_root = vtkGetDataRoot()

    reader = vtk.vtkVolume16Reader()
    reader.SetDataDimensions(64, 64)
    reader.SetDataByteOrderToLittleEndian()
    reader.SetFilePrefix(data_root + "/Data/headsq/quarter")
    reader.SetImageRange(1, 93)
    reader.SetDataSpacing(3.2, 3.2, 1.5)
    reader.Update()

    result = histogram(reader.GetOutput(), logarithmic=0)
    bmp, percent, retvals, xoffset = result
    bmp.SaveFile("C://hello1.bmp", wx.BITMAP_TYPE_BMP)
    print("%s %s %s" % (percent, retvals, xoffset))
    